PyTorch 加载数据集
使用 transforms 加载数据集
transforms 是 PyTorch 的 torchvision 库中的一个非常有用的模块,它提供了一系列预处理功能,可以在加载数据时直接应用于数据集。这对于图像数据集特别有用,因为你经常需要进行诸如裁剪、归一化、增强等操作。
以下是一个使用 transforms 加载 CIFAR10 数据集的示例:
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
# 定义转换操作
transform = transforms.Compose([
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.ToTensor(), # 将 PIL 图像转换为张量
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])
# 加载数据集并应用转换
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataset = CIFAR10(root='./data', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
在上述代码中,我们首先定义了一个转换操作序列,然后在加载 CIFAR10 数据集时将其应用。这样,每次从数据集中提取图像时,都会自动应用这些转换操作。
使用 transforms 可以大大简化数据预处理和增强的工作流程,并确保在训练和评估时都使用相同的预处理步骤。
自定义加载训练集
在 PyTorch 中,加载自定义训练集通常涉及以下几个步骤:
- 定义数据集类:创建一个继承自
torch.utils.data.Dataset的类,并实现__len__和__getitem__方法。 - 创建数据集实例:使用定义的数据集类创建数据集的实例。
- 创建数据加载器:使用
torch.utils.data.DataLoader创建一个数据加载器,以便在训练和评估时批量加载数据。
以下是一个简单的示例,展示如何加载自定义数据集:
1. 定义数据集类
首先,我们需要定义一个数据集类,该类继承自 torch.utils.data.Dataset:
from torch.utils.data import Dataset
from PIL import Image
import os
class CustomDataset(Dataset):
def __init__(self, root_dir, transform=None):
self.root_dir = root_dir
self.transform = transform
self.image_files = [f for f in os.listdir(root_dir) if f.endswith('.png')]
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
img_name = os.path.join(self.root_dir, self.image_files[idx])
image = Image.open(img_name)
label = self.image_files[idx].split('_')[1].split('.')[0] # 假设标签是文件名中的某部分
if self.transform:
image = self.transform(image)
return image, label
在上述代码中,我们定义了一个 CustomDataset 类,该类从指定的 root_dir 加载图像,并从文件名中提取标签。
2. 创建数据集实例
使用定义的数据集类创建数据集的实例:
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize((128, 64)),
transforms.ToTensor(),
])
train_dataset = CustomDataset(root_dir='captcha', transform=transform)
3. 创建数据加载器
使用 torch.utils.data.DataLoader 创建数据加载器:
from torch.utils.data import DataLoader
trainloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
现在,你可以在训练循环中使用 trainloader 来迭代数据:
for batch_data, batch_labels in trainloader:
# 进行训练或评估
pass
这就是在 PyTorch 中加载自定义数据集的基本方法。你可以根据自己的需求进一步扩展和自定义这个过程。
batch_size 的设置
batch_size 是一个参数,用于指定从数据集中一次提取的样本数量,以进行一次迭代的训练或评估。
在深度学习中,我们通常不会一次处理整个数据集,因为这样可能会导致内存不足或计算效率低下。相反,我们将数据集分成多个小批次(batch),每个批次包含一定数量的样本。这种方法称为小批次梯度下降(Mini-batch Gradient Descent)。
具体来说,batch_size=64 意味着每次从 train_dataset 中提取 64 个样本进行训练。这样,网络的权重会在每个批次后更新,而不是在整个数据集上进行一次完整的前向和反向传播后更新。
以下是 batch_size 的一些关键点:
- 计算效率:使用小批次可以利用现代硬件(特别是 GPU)的并行处理能力,从而提高计算效率。
- 内存使用:较小的批次可以减少内存使用,使得大型模型和数据集可以在有限的内存中进行训练。
- 收敛速度:与整批次梯度下降相比,小批次梯度下降通常可以更快地收敛,但可能会在达到最小值时出现震荡。
- 泛化性能:由于每次更新都是基于小批次的数据,这为模型提供了一定的随机性,有助于防止过拟合。
选择合适的 batch_size 是一个实验问题,可能需要根据具体的应用和硬件进行调整。
下面举个具体的例子说明
import torch
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
# 1. 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 2. 加载数据集
dataset = CIFAR10(root='./data', train=True, download=True, transform=transform)
# 假设我们只使用600张图片
subset_dataset = torch.utils.data.Subset(dataset, indices=range(600))
# 3. 创建数据加载器
batch_size = 10
dataloader = DataLoader(subset_dataset, batch_size=batch_size, shuffle=True)
# 4. 定义一个简单的模型和损失函数
model = torch.nn.Linear(3 * 32 * 32, 10)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 5. 训练模型
num_epochs = 10
for epoch in range(num_epochs):
for i, (inputs, labels) in enumerate(dataloader):
# 将输入展平
inputs = inputs.view(inputs.size(0), -1)
# 前向传播
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印每个批次的损失
if (i+1) % 20 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}/60], Loss: {loss.item():.4f}")
在上述代码中:
- 我们首先加载 CIFAR10 数据集并选择其中的600张图片。
- 使用
batch_size = 10,所以每个 epoch 有60个批次。 - 我们训练模型10个 epoch,所以整个数据集会被处理10次。
- 在每个 epoch 中,我们迭代60个批次,对每个批次的数据进行前向和反向传播,并更新模型的权重。